import os, h5py, torch, json
from tqdm import tqdm
from PIL import Image
from torchvision.transforms import ToPILImage
# Converts a CHW image tensor into a PIL image; used to write the visualisation PNGs in section 6.
to_pil = ToPILImage()

# =================================================================
# 1. Configuration (adjust to your environment)
# =================================================================
device = "cuda:0"
data_path = "dataset"                       # path to your mindeye2_src directory
hdf5_file = f"{data_path}/coco_images_224_float16.hdf5"
output_dir = "./stimuli_sets_73k"
os.makedirs(output_dir, exist_ok=True)

prompt_suffix = "prompt_extra1"
clip_name = "large14"  # one of 'large14' or 'big14'
# Fail fast: any other value would fall through the model-loading branches
# and only surface much later as a NameError on `model`.
if clip_name not in ("large14", "big14"):
    raise ValueError(f"Unsupported clip_name {clip_name!r}; expected 'large14' or 'big14'.")

# =================================================================
# 2. ROIs (categories) of interest and the text prompt used for each
# =================================================================
ROIS = dict(
    # --- original ROIs ---
    places="a photo of a house, a building, or an indoor scene",
    bodies="a photo of a human body or a human limb, not focusing on the face",
    faces="a close-up photo of a real human face",
    words="a photo of text, words, or numbers",

    # --- additional ROIs ---
    animals="a photo of an animal, like a cat, dog, or bird",
    food="a photo of food, a dish, or a meal",
    vehicles="a photo of a vehicle, like a car, boat, or airplane",
    nature="a photo of a natural landscape without buildings, like a forest, a mountain, or a beach",
    tools="a photo of a tool or a man-made object, like a hammer, a phone, or a cup",
)

# Parallel views of ROIS in insertion order: names, prompts, and name -> column index
# (the column index addresses the per-class probability matrix computed later).
class_names = [*ROIS]
class_labels = [*ROIS.values()]
class_name_to_idx = dict(zip(class_names, range(len(class_names))))

print(f"成功定义 {len(class_names)} 个 ROI 类别: {', '.join(class_names)}")

# =================================================================
# 3. Load the CLIP model
# =================================================================
if clip_name == "large14":
    # NOTE(review): this path begins with "=" — looks like a typo/placeholder for
    # the local Hugging Face cache directory; confirm it points at the snapshot.
    clip_local = "=Cache/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41"
    from transformers import CLIPModel, CLIPProcessor
    model = CLIPModel.from_pretrained(clip_local).to(device).eval()
    processor = CLIPProcessor.from_pretrained(clip_local)
    print("Loaded CLIP ViT-L/14 model.")
elif clip_name == "big14":
    import open_clip
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-bigG-14',
        pretrained='.cache/modelscope/hub/models/AI-ModelScope/CLIP-ViT-bigG-14-laion2B-39B-b160k/open_clip_pytorch_model.bin',
    )
    # open_clip returns the model in train mode; switch to inference mode explicitly.
    model = model.to(device).eval()
    tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
    print("Loaded OpenCLIP ViT-bigG-14 model.")
else:
    # Without this, `model` would be undefined and later sections would fail obscurely.
    raise ValueError(f"Unsupported clip_name {clip_name!r}; expected 'large14' or 'big14'.")

# =================================================================
# 4. Load image data
# =================================================================
# Opened read-only; `coco_images` is an on-disk dataset handle that the later
# sections index into, so the file stays open until the end of the script.
f = h5py.File(hdf5_file, mode="r")
coco_images = f["images"]
print(f"Loaded {len(coco_images)} images from {hdf5_file}")

# =================================================================
# 5. Use CLIP to compute, for every image, a softmax over all class prompts
# =================================================================
num_classes = len(class_labels)
prob_file = os.path.join(output_dir, f"coco_images_probs_{num_classes}class_{clip_name}_{prompt_suffix}.pt")

# NOTE: probability caches written before this section applied CLIP
# normalisation (and, for big14, the logit scale) are stale — delete them so
# the probabilities are recomputed.
clip_batch_size = 64  # images per forward pass; lower it if the GPU runs out of memory

if not os.path.exists(prob_file):
    print(f"Probability file not found. Running CLIP forward for {num_classes} classes...")
    num_images = len(coco_images)
    # Standard CLIP normalisation statistics (OpenAI values, also open_clip's default).
    # Assumes the stored pixels are in [0, 1] — consistent with the ToPILImage
    # visualisation in section 6; confirm against the dataset.
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
    all_probs = []

    with torch.no_grad():
        if clip_name == "large14":
            # transformers: tokenise the prompts once; the model re-encodes these
            # few prompts alongside each image batch.
            text_inputs = processor(text=class_labels, return_tensors="pt", padding=True)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            desc = "CLIP probs (L/14)"
        elif clip_name == "big14":
            # open_clip: encode and L2-normalise the prompts once.
            text_features = model.encode_text(tokenizer(class_labels).to(device))
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # The learned temperature CLIP applies inside logits_per_image; without
            # it the softmax over cosine similarities is almost uniform.
            logit_scale = model.logit_scale.exp()
            desc = "CLIP probs (bigG/14)"
        else:
            raise ValueError(f"Unsupported clip_name {clip_name!r}; expected 'large14' or 'big14'.")

        for start in tqdm(range(0, num_images, clip_batch_size), desc=desc):
            stop = min(start + clip_batch_size, num_images)
            # The HDF5 stores float16 pixels (per its file name); cast to float32
            # to match the model weights, then apply CLIP normalisation.
            pixels = torch.from_numpy(coco_images[start:stop]).to(device=device, dtype=torch.float32)
            pixels = (pixels - clip_mean) / clip_std
            if clip_name == "large14":
                logits = model(pixel_values=pixels, **text_inputs).logits_per_image
            else:
                image_features = model.encode_image(pixels)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits = logit_scale * image_features @ text_features.T
            all_probs.append(logits.softmax(dim=-1).cpu())

    all_probs = torch.cat(all_probs, dim=0)
    torch.save(all_probs, prob_file)
    print(f"Saved new probabilities to {prob_file}")
else:
    all_probs = torch.load(prob_file, map_location="cpu")
    print(f"Loaded existing probabilities from {prob_file}")
    if all_probs.shape[1] != num_classes:
        raise ValueError(f"Error: Loaded probability file has {all_probs.shape[1]} classes, but current config has {num_classes} classes. Please delete the old file '{prob_file}' and re-run.")
    # Guard against a cache computed for a different image set.
    if all_probs.shape[0] != len(coco_images):
        raise ValueError(f"Error: Loaded probability file has {all_probs.shape[0]} rows, but the HDF5 file has {len(coco_images)} images. Please delete the old file '{prob_file}' and re-run.")

# =================================================================
# 6. For every ROI category, extract the top-K images and save them
# =================================================================
for topk in [100, 200, 300, 500, 1000]:
    viz_dir = os.path.join(output_dir, f"top{topk}_visualization_{clip_name}_{prompt_suffix}")
    os.makedirs(viz_dir, exist_ok=True)
    for target_class in class_names:
        target_class_idx = class_name_to_idx[target_class]

        # Global image indices ranked by this class's probability, highest first.
        top_indices = torch.argsort(all_probs[:, target_class_idx], descending=True)[:topk]

        # h5py fancy indexing requires increasing indices, so read in sorted order:
        # sorted_indices == top_indices[order].
        sorted_indices, order = torch.sort(top_indices)
        sorted_images = torch.from_numpy(coco_images[sorted_indices.numpy()])
        # sorted_images[j] holds image top_indices[order[j]]; indexing with the
        # inverse permutation argsort(order) restores descending-probability order.
        stimuli_set = sorted_images[torch.argsort(order)]

        save_name = f"{target_class}_top{topk}_{clip_name}_{prompt_suffix}.pt"
        torch.save(stimuli_set, os.path.join(output_dir, save_name))

        # Save the 20 highest-ranked images (already in memory) for visual inspection.
        for rank, image in enumerate(stimuli_set[:20]):
            to_pil(image).save(os.path.join(viz_dir, f"{target_class}-{rank+1:03d}.png"))

        print(f"Done! Stimuli set '{save_name}' has been generated.")

f.close()
print("\nAll tasks completed successfully!")